// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package eeg_modelling.protos;

import "google/protobuf/timestamp.proto";


// Time of an event, expressed as offsets in seconds. If an event does not
// have an end time, then only the start time is filled.
message EventTime {
  // Start time of event in seconds, relative to start of segment or chunk.
  optional float start_time_sec = 1;
  // End time of event in seconds, relative to start of segment or chunk.
  // Left unset for events that have no end time.
  optional float end_time_sec = 2;
}

// Kind of attribution stored in an attribution map; referenced by
// Label.AttributionMap.attribution_type.
//
// Value names are not prefixed with the enum name. Enum values are scoped to
// the enclosing package, so these names must stay unique across
// eeg_modelling.protos.
enum AttributionType {
  // First listed value, so it is what an unset optional field of this type
  // reads as in proto2.
  UNSPECIFIED = 0;
  // NOTE(review): presumably plain ("vanilla") gradient attribution -- confirm.
  VANILLA = 1;
  // NOTE(review): presumably Integrated Gradients -- confirm.
  INTGRAD = 2;
  // NOTE(review): presumably SmoothGrad -- confirm.
  SMOOTHGRAD = 3;
}

// Distance metric for nearest-neighbor results; referenced by
// KNN.distance_type.
//
// Value names are not prefixed with the enum name. Enum values are scoped to
// the enclosing package, so these names must stay unique across
// eeg_modelling.protos.
enum DistanceType {
  // First listed value, so it is what an unset optional field of this type
  // reads as in proto2.
  UNKNOWN = 0;
  // NOTE(review): presumably cosine distance -- confirm.
  COSINE = 1;
  // NOTE(review): presumably Euclidean (L2) distance -- confirm.
  L2 = 2;
  // NOTE(review): presumably dot-product based -- confirm how it is turned
  // into a distance.
  DOT = 3;
}

// Nearest-neighbor information attached to a prediction (see
// PredictionOutput.knn).
//
// NOTE(review): keys, distances and labels look like parallel lists where
// index i describes the same neighbor in each -- confirm against the writer.
message KNN {
  // Keys of nearest neighbors
  repeated string keys = 1;
  // Distances to nearest neighbors, in ascending order
  repeated float distances = 2;
  // Labels of nearest neighbors
  repeated float labels = 3;
  // Distance metric used for `distances`.
  optional DistanceType distance_type = 4;
  // Table where inputs were taken from
  optional string input_location = 5;
  // Table where neighbors were taken from
  optional string candidate_location = 6;
}

// Coefficients for handling loose electrode artifacts.
message ArtifactCoefs {
  // NOTE(review): presumably the name of the electrode the coefficients
  // below refer to -- confirm.
  optional string electrode = 1;
  // NOTE(review): meaning and value range are not defined in this file;
  // presumably a phase-reversal score for `electrode` -- confirm.
  optional float phase_reversal_coef = 2;
  // NOTE(review): meaning and value range are not defined in this file --
  // confirm against the code that fills it.
  optional float no_repetition_coef = 3;
}

// A named label together with its predicted value, its actual value (when
// one exists) and an optional attribution map.
message Label {
  // Name of label.
  optional string name = 1;

  // A single value for a label; used for both the predicted and the actual
  // value.
  message Value {
    // Score of label. This is a real number.
    optional float score = 1;
    // Probability of label. This is in (0,1).
    optional float probability = 2;
    // Event time for event based predictions. If used, only the start time
    // may be filled in.
    optional EventTime time = 3;
  }

  // A 2-D grid of attribution values, stored flattened.
  message AttributionMap {
    // Length = height x width, and in row-major order.
    repeated float attribution = 1;
    // Width of the map; see `attribution` for the layout.
    optional int32 width = 2;
    // Height of the map; see `attribution` for the layout.
    optional int32 height = 3;
    // NOTE(review): presumably names of the input features along one axis
    // of the map -- confirm which axis and the expected length.
    repeated string feature_names = 4;
    // Type of attribution stored in `attribution`.
    optional AttributionType attribution_type = 5;
  }

  // Model's prediction for the given label.
  optional Value predicted_value = 2;
  // Actual value for label if it exists.
  optional Value actual_value = 3;
  // Attribution map for this label.
  optional AttributionMap attribution_map = 4;
}

// Identifying information for a segment.
message SegmentInfo {
  // Identifier for segment.
  optional string segment_id = 1;
}

// Identifying and timing information for a chunk.
message ChunkInfo {
  // Identifier for chunk.
  optional string chunk_id = 1;
  // Absolute start time of chunk.
  optional google.protobuf.Timestamp chunk_start_time = 2;
  // Size of chunk in seconds.
  optional float chunk_size_sec = 3;
}

// One prediction output: per-label results together with the patient, chunk,
// segment and model it belongs to.
message PredictionOutput {
  // Identifier for the patient.
  optional string patient_id = 1;
  // Chunk this output refers to.
  optional ChunkInfo chunk_info = 2;
  // Segment this output refers to.
  optional SegmentInfo segment_info = 3;
  // Per-label results. The field is repeated although its name is singular;
  // renaming it would change generated accessors and text/JSON keys.
  repeated Label label = 4;
  // Identifier for the model used to generate the prediction output.
  optional string model_id = 5;
  // Features at representation layer
  repeated float representation = 6;
  // Information about nearest neighbors
  optional KNN knn = 7;
  // Coefficients to reject artifacts.
  optional ArtifactCoefs artifact_coefs = 8;
}

// List of prediction outputs. Some examples include predictions for all chunks
// for a given segment, a given 24-hour time frame, or a given patient.
message PredictionOutputs {
  // The individual prediction outputs. The field is repeated although its
  // name is singular.
  repeated PredictionOutput prediction_output = 1;
}
